package cn.piflow.rpc

import java.lang.reflect.{InvocationHandler, Method}
import javax.net.ssl.SSLContext
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import cn.piflow.util.Logging
import org.apache.http.client.entity.EntityBuilder
import org.apache.http.client.methods.HttpPost
import org.apache.http.config.RegistryBuilder
import org.apache.http.conn.socket.{ConnectionSocketFactory, PlainConnectionSocketFactory}
import org.apache.http.conn.ssl.SSLConnectionSocketFactory
import org.apache.http.entity.ContentType
import org.apache.http.impl.client.HttpClients
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}

/**
	* Created by bluejoe on 2017/10/26.
	*/

/**
	* Something that can expose service objects under a name.
	*/
trait ServiceContainer {
	/**
		* Registers `service` under `name`.
		*
		* @param name    key under which the service is registered
		* @param service the object to expose
		* @return this container, so registrations can be chained
		*/
	def serve(name: String, service: Object): this.type
}

/**
	* Publishes registered service objects over HTTP through an embedded Jetty server.
	*
	* Services registered with `serve` are copied when `start` is called and are dispatched
	* by a single [[HttpRpcServlet]] mounted at the given servlet path.
	*/
class HttpRpcServer extends ServiceContainer with Logging {
	// the Jetty instance created by `start`; None before `start` and after `stop`
	private var server: Option[Server] = None;

	val stockServices = collection.mutable.Map[String, Object]();

	val serializer = SerializerFactory.DEFAULT.getSerializer("kryo");

	/**
		* Registers `service` under `name`. Only services registered before `start` are served.
		*/
	def serve(name: String, service: Object): this.type = {
		stockServices += (name -> service);
		this;
	}

	/**
		* Registers `service` under the fully qualified name of `clazz`.
		*/
	def serve(clazz: Class[_], service: Object): this.type = {
		serve(clazz.getName, service);
	}

	/**
		* Starts Jetty on `httpPort` with the RPC servlet mounted at `httpServletPath`.
		*
		* @param httpPort        port to listen on
		* @param httpServletPath servlet mapping of the RPC endpoint
		*/
	def start(httpPort: Int, httpServletPath: String): Unit = {
		val jetty = new Server(httpPort);
		val context = new ServletContextHandler(ServletContextHandler.SESSIONS);
		context.setContextPath("/");
		jetty.setHandler(context);
		//add servlet; the services are copied so later `serve` calls do not affect it
		val map = stockServices.toMap;
		context.addServlet(new ServletHolder(new HttpRpcServlet(map, serializer)), httpServletPath);
		// remember the instance before starting so `stop` can also clean up a failed start
		server = Some(jetty);
		jetty.start();

		logger.info(s"start HTTP RPC service on http://localhost:$httpPort$httpServletPath");
		logger.debug(s"stock beans: $map");
	}

	/**
		* Stops the Jetty server created by `start`. Does nothing if no server was started.
		*/
	def stop(): Unit = {
		server.foreach(_.stop());
		server = None;
	}
}

/**
	* Outcome of a remote method invocation, as sent back to the client.
	*
	* @param value     the invocation's return value, when it completed normally
	* @param exception the throwable raised by the invocation, when it failed
	*/
case class ReturnedValue(value: Option[Object], exception: Option[Throwable])

/**
	* Servlet that reads a serialized [[MethodCall]] from the request body, invokes it on the
	* named bean and writes back a serialized [[ReturnedValue]].
	*
	* @param stockBeans beans callable by clients, keyed by bean name
	* @param serializer serializer used for both the request and the response body
	*/
class HttpRpcServlet(stockBeans: Map[String, Object], serializer: Serializer) extends HttpServlet with Logging {
	override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = {
		// SECURITY NOTE(review): this deserializes arbitrary object graphs from the request
		// body, which is unsafe for untrusted callers; expose this endpoint to trusted clients only.
		val dis = serializer.newInstance().deserializeStream(request.getInputStream);
		val methodCall = try {
			dis.readObject[MethodCall]();
		}
		finally {
			dis.close();
		}
		logger.debug(s"received method call: $methodCall");

		val returned = try {
			val bean = stockBeans.getOrElse(methodCall.beanName,
				throw new NoSuchElementException(s"no such bean: ${methodCall.beanName}"));
			ReturnedValue(Some(methodCall.invoke(bean)), None);
		}
		catch {
			// fatal errors (OOM etc.) propagate to the container instead of being serialized
			case scala.util.control.NonFatal(e) =>
				logger.warn(s"failed to invoke $methodCall", e);
				ReturnedValue(None, Some(e));
		}

		logger.debug(s"responding to client: $returned");
		response.setContentType("application/octet-stream");
		val sos = serializer.newInstance().serializeStream(response.getOutputStream);
		try {
			sos.writeObject(returned);
			sos.flush();
		}
		finally {
			sos.close();
		}
	}

	override def destroy(): Unit = {
	}
}

/**
	* A serializable description of a method invocation: the target bean, the method
	* (identified by name and parameter types) and the argument values.
	*
	* @param beanName   name of the target bean on the server
	* @param methodName name of the method to invoke
	* @param paramTypes declared parameter types, used to select among overloads
	* @param params     argument values, in declaration order
	*/
case class MethodCall(beanName: String, methodName: String, paramTypes: List[Class[_]], params: List[Object]) {
	/**
		* Invokes this call reflectively on `bean`.
		*
		* The method is resolved with `getMethod`, so it must be public. When the invoked
		* method throws, that exception is rethrown as-is instead of being wrapped in
		* `InvocationTargetException`, so callers see the real failure.
		*
		* @param bean object to invoke the method on
		* @return the method's result (boxed for primitives, `null` for `void` methods)
		* @throws NoSuchMethodException if no matching public method exists
		*/
	def invoke(bean: Object): Object = {
		val method = bean.getClass.getMethod(methodName, paramTypes: _*);
		try {
			method.invoke(bean, params: _*);
		}
		catch {
			case e: java.lang.reflect.InvocationTargetException =>
				throw Option(e.getCause).getOrElse(e);
		}
	}
}

/**
	* Client side of the HTTP RPC mechanism. It builds dynamic proxies whose method calls are
	* serialized, POSTed to `httpServletUrl` and answered by an [[HttpRpcServlet]].
	*
	* @param httpServletUrl full URL of the remote RPC servlet
	*/
class HttpRpcClient(httpServletUrl: String) extends Logging {
	val sslsf = new SSLConnectionSocketFactory(SSLContext.getDefault());
	val socketFactoryRegistry = RegistryBuilder.create[ConnectionSocketFactory]()
		.register("https", sslsf)
		.register("http", new PlainConnectionSocketFactory())
		.build();

	val connectionManager = new PoolingHttpClientConnectionManager(socketFactoryRegistry);
	connectionManager.setMaxTotal(200);
	connectionManager.setDefaultMaxPerRoute(20);

	val serializer = SerializerFactory.DEFAULT.getSerializer("kryo");

	// One client shared by every call, backed by the pooled connection manager above,
	// instead of building a new client per request.
	private lazy val httpClient = HttpClients.custom()
		.setConnectionManager(connectionManager)
		.build();

	/**
		* Creates a dynamic proxy implementing `interface`. Every method called on the proxy
		* is executed remotely on the bean registered under `beanName`.
		*
		* @param beanName  name of the target bean on the server
		* @param interface interface the proxy implements
		* @tparam T type the proxy is cast to; presumably `interface` itself
		*/
	def createProxy[T](beanName: String, interface: Class[_]): T = {
		val handler = new InvocationHandler {
			override def invoke(proxy: Object, method: Method, args: Array[Object]): Object = {
				// the proxy passes null (not an empty array) for zero-argument methods
				val params = if (args == null) List() else args.toList;
				doRemoteMethodCall(MethodCall(beanName, method.getName, method.getParameterTypes.toList, params));
			}
		};
		java.lang.reflect.Proxy.newProxyInstance(this.getClass.getClassLoader(), Array(interface), handler).asInstanceOf[T];
	}

	/**
		* Serializes `methodCall`, POSTs it to the servlet and decodes the reply.
		*
		* @return the remote method's return value
		* @throws HttpRpcServerSideException on a non-200 status, an empty or malformed reply,
		*                                    or when the remote invocation itself failed
		*/
	private def doRemoteMethodCall(methodCall: MethodCall): Object = {
		val post = new HttpPost(httpServletUrl);
		post.setEntity(EntityBuilder.create()
			.setBinary(serialize(methodCall))
			.setContentType(ContentType.APPLICATION_OCTET_STREAM)
			.build());

		logger.debug(s"sending method call: $methodCall");
		val resp = httpClient.execute(post);
		val returned: ReturnedValue = try {
			//server side exception
			val status = resp.getStatusLine;
			if (status.getStatusCode != 200) {
				throw new HttpRpcServerSideException(status.getReasonPhrase);
			}

			val entity = resp.getEntity;
			if (entity == null) {
				throw new HttpRpcServerSideException(s"empty response for method call: $methodCall");
			}

			val is = entity.getContent;
			try {
				serializer.newInstance().deserializeStream(is).readObject[ReturnedValue]();
			}
			finally {
				is.close();
			}
		}
		finally {
			// always release the response; otherwise error replies leak pooled connections
			resp.close();
		}

		logger.debug(s"response from server: $returned");
		returned match {
			case ReturnedValue(Some(v), None) => v;
			case ReturnedValue(None, Some(ex)) =>
				throw new HttpRpcServerSideException(
					s"remote invocation of ${methodCall.beanName}.${methodCall.methodName} failed", ex);
			case other =>
				throw new HttpRpcServerSideException(s"malformed response from server: $other");
		}
	}

	// Copies exactly the serialized bytes. `ByteBuffer.array()` returns the whole backing
	// array, which can be longer than the data when the buffer wraps only part of it.
	private def serialize(methodCall: MethodCall): Array[Byte] = {
		val buffer = serializer.newInstance().serialize(methodCall);
		val bytes = new Array[Byte](buffer.remaining());
		buffer.get(bytes);
		bytes;
	}
}

/**
	* Holds the default [[SerializerFactory]].
	*/
object SerializerFactory {
	/**
		* Factory that lower-cases the requested name and maps "kryo" to a `KryoSerializer`
		* and "java" to a `JavaSerializer`, each built from a fresh `SparkConf`.
		* Any other name raises [[InvalidSerializerNameException]].
		*/
	val DEFAULT: SerializerFactory = new SerializerFactory {
		override def getSerializer(serializerName: String): Serializer = {
			val key = serializerName.toLowerCase();
			if (key == "kryo") {
				new KryoSerializer(new SparkConf());
			}
			else if (key == "java") {
				new JavaSerializer(new SparkConf());
			}
			else {
				throw new InvalidSerializerNameException(serializerName);
			}
		}
	}
}

/**
	* Looks up a Spark [[Serializer]] by name.
	*/
trait SerializerFactory {
	/**
		* @param serializerName name identifying the serializer implementation
		* @return the serializer registered for `serializerName`
		*/
	def getSerializer(serializerName: String): Serializer
}

/**
	* Thrown on the client side when a remote call is rejected by, or fails on, the server.
	*
	* @param msg   description of the failure
	* @param cause underlying exception, or `null` when there is none
	*/
class HttpRpcServerSideException(msg: String, cause: Throwable = null)
	extends RuntimeException(msg, cause)

/**
	* Thrown when a serializer is requested under a name the factory does not recognise.
	*
	* @param serializerName the rejected name, included in the exception message
	*/
class InvalidSerializerNameException(serializerName: String)
	extends RuntimeException(s"invalid serializer name: $serializerName")
